import numpy as np
import torch
from torch.optim import Adam
import safety_gym
import gym
import time
import  core
from utils.logx import EpochLogger
from utils.mpi_pytorch import setup_pytorch_for_mpi, sync_params, mpi_avg_grads
from utils.mpi_tools import mpi_fork, mpi_avg, proc_id, mpi_statistics_scalar, num_procs
from torch.nn.functional import softplus
# Turn on autograd anomaly detection globally; this runs as a side effect of
# importing/executing this module.
# NOTE(review): PyTorch documents anomaly detection as a debugging aid that
# slows execution — confirm it is still wanted for this rollout script.
torch.autograd.set_detect_anomaly(True)
import csv


def ppo(env_fn, actor_critic=core.MLPActorCritic, ac_kwargs=dict(),
        seed=0, 
        steps_per_epoch=4000, epochs=50, gamma=0.99, clip_ratio=0.2, pi_lr=3e-4,
        vf_lr=1e-3, train_pi_iters=80, train_v_iters=80, lam=0.97, max_ep_len=1000,
        target_kl=0.01, logger_kwargs=dict(), save_freq=1,
        storage_intest_number_array = [25000, 25000],
        intest_trajectory_step = 60,
        agent_checkpoint_mode = "error"
        ):
    """Roll out a saved actor-critic and collect labelled (observation, action) rows.

    Despite the name and the PPO-style signature, this function performs no
    policy update. It loads an actor-critic from
    ``<agent_checkpoint_mode>/checkpoint/<agent_checkpoint_mode>_10000.pt``,
    steps it in the environment, and on every 10th step queues the current
    (observation, action) pair as a candidate sample, provided that step's
    ``info['cost']`` is not positive. A queued candidate is then committed as:

    * label 0 -- when ``info['goal_met']`` is truthy while it is queued, or
      when its age counter reaches ``intest_trajectory_step`` without a goal
      or a positive-cost step;
    * label 1 -- when a step with positive cost occurs while it is queued
      (the environment is also reset in that case).

    Candidates still queued when an episode terminates, times out, or the
    epoch ends are discarded. Each label has its own quota; commits beyond
    the quota are counted but not stored. When the number of stored rows
    equals the sum of both quotas (checked at episode/epoch boundaries), the
    rows are written to
    ``<agent_checkpoint_mode>/intest/parallel/intest_obs<seed>.csv`` and the
    process exits.

    Args:
        env_fn: Zero-argument callable returning the environment.
        seed: Seed for NumPy and PyTorch; also part of the CSV file name.
        steps_per_epoch: Total steps per epoch, divided by ``num_procs()``.
        epochs: Number of epochs to run at most.
        max_ep_len: Episode length after which the episode is treated as done.
        logger_kwargs: Keyword arguments forwarded to ``EpochLogger``.
        save_freq: ``logger.save_state`` is called every ``save_freq`` epochs.
        storage_intest_number_array: ``[label-0 quota, label-1 quota]``.
        intest_trajectory_step: Age at which a queued candidate is committed
            with label 0.
        agent_checkpoint_mode: Directory prefix for the checkpoint and outputs.
        actor_critic, ac_kwargs, gamma, clip_ratio, pi_lr, vf_lr,
        train_pi_iters, train_v_iters, lam, target_kl: Accepted for signature
            compatibility; not used by this function.

    Raises:
        SystemExit: with status 1 if the checkpoint cannot be loaded; with the
            default (success) status once both quotas have been filled.
    """
    # NOTE: the dict()/list defaults above are shared between calls; that is
    # safe only because this function never mutates them.

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()
    print("SEED!!!!!!!!!!", seed)
    # Set up logger and save configuration. This must stay ahead of any other
    # local binding so that locals() holds only the arguments and the logger.
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    # Function-scope import keeps the file's import header untouched; placed
    # after save_config() so it does not show up in the saved configuration.
    from collections import deque

    # Random seed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)

    # NOTE(review): the environment is constructed twice and only the second
    # instance is used for rollouts. Left as-is because dropping one
    # construction could change random-number consumption -- confirm.
    env = env_fn()
    n_observations = env.observation_space.shape[0]

    # Instantiate environment
    env = env_fn()
    act_dim = env.action_space.shape

    local_steps_per_epoch = int(steps_per_epoch / num_procs())

    # Path setting
    ppo_checkpoint_path = agent_checkpoint_mode+"/checkpoint/"+agent_checkpoint_mode+"_10000.pt"

    # Candidates waiting for a label: [row, age] pairs, oldest on the left.
    pending = deque()
    collect_intest_cycle = 0

    # Committed rows, kept per label (index 0 / 1) in commit order. Python
    # lists give amortised O(1) appends; growing one array with np.append
    # would copy every stored row on each commit.
    storage_intest_number = storage_intest_number_array[0] + storage_intest_number_array[1]
    stored_rows = ([], [])
    # A stored row is the observation followed by the action; the width is
    # fixed at n_observations + 2, so any other action size fails the reshape
    # performed when a candidate is queued.
    row_width = n_observations + 2
    jinjja_context = [0, 0]   # commits attempted per label (quota ignored)
    context_real = [0, 0]     # rows actually stored per label

    try:
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        ac = torch.load(ppo_checkpoint_path)
        print("    LOAD CHECKPOINT %s" %(ppo_checkpoint_path))
        # Overwrite the policy's log-std with a constant -2 per action dim.
        log_std = -2 * np.ones(act_dim[0], dtype=np.float32) 
        ac.pi.log_std.data = torch.as_tensor(log_std) 
    except Exception as e:
        print(e)
        print("    CANNOT LOAD CHECKPOINT")
        # Non-zero status so launchers can tell the run failed.
        raise SystemExit(1)

    # Count variables
    var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n'%var_counts)

    def commit_sample(row, label):
        """Count one commit for `label`; store `row` if the quota allows."""
        jinjja_context[label] += 1
        if context_real[label] < storage_intest_number_array[label]:
            stored_rows[label].append(row)
            context_real[label] += 1

    def commit_all_pending(label):
        """Commit every queued candidate, oldest first, with `label`."""
        while pending:
            row, _age = pending.popleft()
            commit_sample(row, label)

    def write_intest_csv():
        """Write both counts, then label-0 rows, then label-1 rows, to CSV."""
        assert len(stored_rows[0]) == context_real[0]
        assert len(stored_rows[1]) == context_real[1]

        default_path = agent_checkpoint_mode+"/intest/parallel/"+"intest_obs"+str(seed)+".csv"
        # `with` closes the file even if a write fails; newline="" is what
        # the csv module asks for on files handed to csv.writer.
        with open(default_path, "w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow([context_real[0]])
            writer.writerow([context_real[1]])
            writer.writerows(stored_rows[0])
            writer.writerows(stored_rows[1])
            writer.writerow("")  # trailing empty record, as before

    # Prepare for interaction with environment
    print("    START NEW SESSION!", epochs)
    o, ep_ret, ep_cret, ep_len = env.reset(), 0, 0, 0

    # Main loop: collect experience in env; no policy update is performed.
    for epoch in range(epochs):
        for t in range(local_steps_per_epoch):
            a, _, _, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))

            next_o, r, d, info = env.step(a)
            c = info['cost']
            goal_checking = info.get('goal_met')
            safety_label = 1 if c > 0 else 0

            # Every 10th step is a sampling step; the cycle restarts whether
            # or not the step ends up being queued.
            collect_intest_cycle += 1
            if collect_intest_cycle >= 10:
                collect_intest_cycle = 0
                if safety_label == 0:
                    # Row is built from the observation the action was
                    # chosen on (o has not been advanced yet).
                    state_action = np.concatenate(
                        (o.reshape((1,)+o.shape), np.array([a])), axis=1)
                    row = np.asarray(state_action, dtype=np.float64).reshape(row_width)
                    pending.append([row, 0])
                # Sampling steps with positive cost are skipped.

            ep_ret += r
            ep_cret += c
            ep_len += 1

            # Update obs (critical!)
            o = next_o

            if goal_checking:
                commit_all_pending(0)
            elif safety_label == 1:
                commit_all_pending(1)
                # NOTE(review): the environment is reset here but the episode
                # counters (ep_ret / ep_cret / ep_len) are not, and no
                # episode is logged -- behaviour kept unchanged; confirm.
                o = env.reset()
            elif pending:
                # Age every queued candidate (including one queued on this
                # very step), then commit the oldest once it reaches the
                # configured age.
                for entry in pending:
                    entry[1] += 1
                if pending[0][1] == intest_trajectory_step:
                    row, _age = pending.popleft()
                    commit_sample(row, 0)

            timeout = ep_len == max_ep_len
            terminal = d or timeout
            epoch_ended = t==local_steps_per_epoch-1

            if terminal or epoch_ended:
                if epoch_ended and not(terminal) and proc_id() == 0:
                    print('Warning: trajectory cut off by epoch at %d steps.'%ep_len, flush=True)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len, EpCost=ep_cret)
                o, ep_ret, ep_cret, ep_len = env.reset(), 0, 0, 0
                # Candidates that never received a label are dropped.
                pending.clear()
                if proc_id() == 0:
                    with open(agent_checkpoint_mode+"/intest/intest_timelog.txt", "a") as f:
                        f.write("context: [%s,%s] [%s,%s]\n" % (
                            jinjja_context[0], jinjja_context[1],
                            context_real[0], context_real[1]))
                if storage_intest_number == context_real[0] + context_real[1]:
                    write_intest_csv()
                    print("FINISH")
                    # Both quotas are full: stop this process (success status).
                    raise SystemExit

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs-1):
            logger.save_state({'env': env}, None)

        # Per-epoch tabular logging (logger.log_tabular / dump_tabular) is
        # not performed in this function.

if __name__ == '__main__':
    # Script entry point: parse the command line, re-launch under MPI, then
    # run the collection routine with one seed per MPI process.
    import argparse

    # (flags, type, default) for every option, in declaration order.
    # --steps and --epochs are parsed but not read below.
    option_table = (
        (('--env',), str, 'Safexp-PointGoal1-v0'),
        (('--hid',), int, 256),
        (('--l',), int, 2),
        (('--gamma',), float, 0.99),
        (('--seed', '-s'), int, 0),
        (('--cpu',), int, 200),
        (('--steps',), int, 4000),
        (('--epochs',), int, 50),
        (('--exp_name',), str, 'ppo_point_collectinest'),
        (('--mode',), str, 'error'),
    )
    cli = argparse.ArgumentParser()
    for flags, value_type, default_value in option_table:
        cli.add_argument(*flags, type=value_type, default=default_value)
    opts = cli.parse_args()

    print(opts.mode)
    mpi_fork(opts.cpu)  # run parallel code with mpi

    # Imported only after the MPI fork, as in the original ordering.
    from utils.run_utils import setup_logger_kwargs
    run_logger_kwargs = setup_logger_kwargs(opts.exp_name, opts.seed)

    # The epoch count is derived from a fixed step budget, with 1000 steps
    # per process per epoch.
    total_env_steps = 1e16
    per_epoch_steps = 1000*opts.cpu
    n_epochs = int(total_env_steps / per_epoch_steps)

    hidden_layout = [opts.hid]*opts.l
    ppo(
        lambda : gym.make(opts.env),
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(hidden_sizes=hidden_layout),
        gamma=opts.gamma,
        seed=proc_id(),
        steps_per_epoch=per_epoch_steps,
        epochs=n_epochs,
        logger_kwargs=run_logger_kwargs,
        agent_checkpoint_mode=opts.mode,
    )
